{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import sys\n",
    "\n",
    "## Put the parent directory on the import path before importing pb_diff_envs\n",
    "sys.path.append('..')\n",
    "from pb_diff_envs.environment.static.rand_kuka_env import RandKukaEnv\n",
    "\n",
    "import numpy as np\n",
    "from importlib import reload\n",
    "import matplotlib.pyplot as plt\n",
    "import pybullet as p\n",
    "import pb_diff_envs.objects.static.voxel_group as vv\n",
    "reload(vv)  # re-import voxel_group so edits to that module are picked up\n",
    "from pb_diff_envs.environment.static.kuka_rand_vgrp import Kuka_VoxelRandGroupList\n",
    "\n",
    "\n",
    "## ---- Scalar settings for the environment group ----\n",
    "## (avg one traj len / horizon: not recorded)\n",
    "num_groups = 5\n",
    "num_voxels = 5\n",
    "hr = 0.20  # 0.2; reused for all three entries of hExt_range\n",
    "vr_x = 0.2\n",
    "vr_y = 0.2\n",
    "vr_z = (0., 0.)\n",
    "se_x = 0.5\n",
    "se_y = 0.5\n",
    "se_z = (0.3, 0.9)  # 0.4 too difficult\n",
    "seed = 333\n",
    "\n",
    "## ---- Range arrays handed to the group constructor ----\n",
    "## TODO (original note): check z\n",
    "start_end = np.array([[-se_x, se_x], [-se_y, se_y], se_z])\n",
    "hExt_range = np.array([hr, hr, hr])  # fixed-size obstacle\n",
    "void_range = np.array([(-vr_x, vr_x), (-vr_y, vr_y), vr_z])\n",
    "\n",
    "## Extra keyword arguments forwarded verbatim via **trivial\n",
    "trivial = {\n",
    "    'gen_data': True,\n",
    "    'samples_per_env': 200,\n",
    "    'gen_num_parts': 1,\n",
    "    'dataset_url': './dataset/luotest-k7d-ipynb.hdf5',\n",
    "    'eval_num_groups': 2,  ## dummy placeholder\n",
    "    'eval_num_probs': 2,  ## dummy placeholder\n",
    "}\n",
    "## interp_density is read again later when building SolutionInterp\n",
    "mazelist_config = {\n",
    "    'planner_timeout': 30,\n",
    "    'interp_density': 3,\n",
    "    'min_episode_distance': 4 * np.pi,\n",
    "}\n",
    "\n",
    "## Manually create a group of envs\n",
    "dvg_group = Kuka_VoxelRandGroupList(\n",
    "    RandKukaEnv,\n",
    "    num_groups,\n",
    "    num_voxels,\n",
    "    start_end,\n",
    "    void_range,\n",
    "    orn_range=None,\n",
    "    hExt_range=hExt_range,\n",
    "    is_static=True,\n",
    "    seed=seed,\n",
    "    mazelist_config=mazelist_config,\n",
    "    **trivial,\n",
    ")\n",
    "\n",
    "## Alternative (disabled): create a predefined group of envs through gym\n",
    "# import gym\n",
    "# dvg_group = gym.make('kuka7d-testOnly-v8', load_mazeEnv=False, gen_data=True) # False) # True"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple visualization of the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "## Drop any pybullet connection that is still open before loading the env\n",
    "if p.isConnected():\n",
    "    p.disconnect()\n",
    "\n",
    "## Build env number 3 of the group, then unload and load it again\n",
    "env = dvg_group.create_single_env(3)\n",
    "env.unload_env()\n",
    "env.load(GUI=False)  # if in a headless server: False; local machine: True\n",
    "p.getContactPoints()  # result is discarded\n",
    "p.stepSimulation()\n",
    "# env.robot.get_workspace_observation()\n",
    "\n",
    "## Show one rendered frame of the environment\n",
    "plt.imshow(env.render())\n",
    "plt.show()\n",
    "\n",
    "p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)  # Disable the default GUI controls\n",
    "p.configureDebugVisualizer(p.COV_ENABLE_MOUSE_PICKING, 1)  # Enable mouse picking\n",
    "\n",
    "## Draw the x, y and z axes from the origin; each line's RGB colour equals its unit direction\n",
    "origin = [0, 0, 0]\n",
    "for axis in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):\n",
    "    p.addUserDebugLine(origin, axis, lineColorRGB=axis, lineWidth=6)\n",
    "\n",
    "no_while = True  # True: leave the loop below after a single pass\n",
    "\n",
    "while True:\n",
    "    p.stepSimulation()\n",
    "    # p.performCollisionDetection()\n",
    "    c_pts = p.getContactPoints()\n",
    "    time.sleep(0.1)\n",
    "    has_contacts = c_pts is not None and len(c_pts) > 0\n",
    "    if has_contacts:\n",
    "        # placeholder for contact debugging, e.g.:\n",
    "        # print(p.getContactPoints())\n",
    "        # time.sleep(2)\n",
    "        # p.resetJointStatesMultiDof(env.robot_id, np.arange(7), [[0.]]*7)\n",
    "        pass\n",
    "    if no_while:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Scatter plot of 2000 positions drawn from the group's cuboid sampler\n",
    "from utils.save_utils import save_scatter_fig\n",
    "# save_scatter_fig (env.wall_locations)\n",
    "points = []\n",
    "while len(points) < 2000:\n",
    "    # the sampler returns a pair; only the first element is used here\n",
    "    pt, _ = dvg_group.cuboid_grp.sample_xyz_hExt()\n",
    "    # keep the first two components (presumably x and y -- confirm against the sampler)\n",
    "    points.append(pt[:2])\n",
    "save_scatter_fig(points, None, n_kuka=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional manual step, left disabled: uncomment the next line to call env.unload_env().\n",
    "# env.unload_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''Only for debug, used with a connected monitor'''\n",
    "# The whole cell is switched off by `if False:`; change the condition to run these checks.\n",
    "# It relies on `env`, `p`, `plt` and `time` defined by earlier cells.\n",
    "if False:\n",
    "    print(env)\n",
    "    # p.getOverlappingObjects((0.01, .01, .01))\n",
    "    # Run collision detection, then print every reported contact point\n",
    "    p.performCollisionDetection()\n",
    "    c_points = p.getContactPoints()\n",
    "    for cp in c_points:\n",
    "        print(cp)\n",
    "    # Inspect bodies by hard-coded pybullet ids (0 and 2); which objects these ids\n",
    "    # refer to depends on load order -- TODO confirm\n",
    "    print(p.getNumJoints(2))\n",
    "    print(p.getBasePositionAndOrientation(0))\n",
    "    # p.disconnect()\n",
    "    print(p.getConnectionInfo())\n",
    "    print(p.getAABB(0)) # a bounding box, min point and max point\n",
    "    print(p.getBodyInfo(2))\n",
    "\n",
    "    plt.imshow(env.render())\n",
    "    start_time = time.time()\n",
    "\n",
    "    # NOTE: the unconditional `break` at the end of the body means this loop runs\n",
    "    # exactly one iteration; the 10-second branch is not expected to fire on that\n",
    "    # first pass (only ~0.05 s has elapsed). Remove the `break` to poll continuously.\n",
    "    while True:\n",
    "        time.sleep(0.05)\n",
    "        elapsed_time = time.time() - start_time\n",
    "        if elapsed_time > 10:\n",
    "            start_time = time.time()\n",
    "            env.robot.print_joint_states()\n",
    "            print('no_collision:', env.robot.no_collision())\n",
    "        p.stepSimulation()\n",
    "        break\n",
    "    # Dump robot limits, joint states and ids\n",
    "    print(f'env.robot.limits_high {env.robot.limits_high}')\n",
    "    env.robot.print_joint_states()\n",
    "    print(env.robot.item_id)\n",
    "    print(env.robot_id) # list\n",
    "    # bare expression inside the `if` block: evaluated, but not echoed as cell output\n",
    "    env.robot.limits_high\n",
    "    # env.robot.limits_low\n",
    "    print(env._get_joint_pose())\n",
    "    print(f'{type(env._get_joint_pose())}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample a problem/episode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from utils.utils import DotDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "env.load(GUI=False)  # if in a headless server: False; local machine: True\n",
    "\n",
    "## Start from the all-zero 7-joint configuration\n",
    "prev_pose = np.zeros(7)\n",
    "env.robot.set_config(prev_pose)\n",
    "\n",
    "n_traj = 2  # 20\n",
    "solutions = []\n",
    "for i in tqdm(range(n_traj)):\n",
    "    ## sample_1_episode returns a pair; the first element is a sparse trajectory\n",
    "    waypoints, _ = env.sample_1_episode(prev_pose)\n",
    "    solution = np.array(waypoints)\n",
    "    ## the last waypoint is passed in as the start pose of the next episode\n",
    "    prev_pose = solution[-1]\n",
    "    solutions.append(solution)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import pybullet as p\n",
    "import pybullet_data\n",
    "import numpy as np\n",
    "from utils.utils import save_gif\n",
    "from IPython.display import HTML\n",
    "import base64\n",
    "from objects.trajectory import WaypointLinearTrajectory, WaypointProportionTrajectory\n",
    "from utils.kuka_utils_luo import SolutionInterp\n",
    "from importlib import reload\n",
    "import utils.robogroup_utils_luo  as rul; reload(rul)\n",
    "from utils.robogroup_utils_luo import robogroup_visualize_traj_luo\n",
    "from utils.kuka_utils_luo import visualize_kuka_traj_luo\n",
    "\n",
    "## Render one sampled trajectory to a GIF and embed it in the notebook.\n",
    "## Relies on `env`, `solutions`, `mazelist_config` and `DotDict` from earlier cells.\n",
    "noise_config = {'std': 0.000}  # not used in this cell\n",
    "\n",
    "## Interpolate the sparse solution with the density set in mazelist_config\n",
    "sol_interp = SolutionInterp(density=mazelist_config['interp_density'])\n",
    "sol_i = solutions[0] # which motion trajectory to visualize\n",
    "result = DotDict(solution=sol_i)\n",
    "traj_vis = sol_interp(result.solution,)\n",
    "\n",
    "gifs, ds, vis_dict = visualize_kuka_traj_luo(env, traj_vis, is_debug=False)\n",
    "\n",
    "## Shapes of the interpolated and the original (sparse) trajectory\n",
    "print(traj_vis.shape)\n",
    "print(result.solution.shape)\n",
    "\n",
    "gif_name = './test_bit_kuka.gif'\n",
    "save_gif(gifs, gif_name, duration=ds)\n",
    "## Read the GIF back inside a `with` block so the file handle is closed right after\n",
    "## reading, then inline it as a base64 data URI.\n",
    "with open(gif_name, 'rb') as gif_file:\n",
    "    b64 = base64.b64encode(gif_file.read()).decode('ascii')\n",
    "display(HTML(f'<img src=\"data:image/gif;base64,{b64}\" />'))\n",
    "\n",
    "# import mediapy as mpy\n",
    "# gifs = np.array(gifs)\n",
    "# gifs.shape\n",
    "# mpy.show_video(gifs[:, :,:,:3], fps=10)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "vscode": {
   "interpreter": {
    "hash": "c1daed7e4d34e14fbb9bb260207f2e4074aea67593e999acdc484a19561ddc79"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
